DeepONet: Learning diffusivity (m) to solution (u) map for the Poisson problem¶
Data is located in ../data directory, and key data of our interest is in Poisson_samples.npz file.
On data¶
The Dropbox folder NeuralOperator_Survey_Shared_Data_March2025 contains the key data to reproduce the results in the survey paper.
If you did not generate data by running survey_work/problems/poisson/Poisson.ipynb, consider copying the contents of dropbox folder NeuralOperator_Survey_Shared_Data_March2025/survey_work/problems/poisson/data/ into survey_work/problems/poisson/data/ before running this notebook.
Results¶
Below shows the neural operator prediction for different samples of test input.
In [1]:
import sys
import os
import torch
import numpy as np
src_path = "../../src/"
sys.path.append(src_path + 'plotting/')
from field_plot import field_plot
from plot_loss import plot_loss
sys.path.append(src_path + 'data/')
from dataMethods import DataProcessor
sys.path.append(src_path + 'nn/deeponet/')
sys.path.append(src_path + 'nn/mlp/') # need this here so that DeepONet can be imported (it imports MLP)
from torch_deeponet import DeepONet
import uq
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'
# set seed
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
from sklearn.metrics import roc_auc_score, roc_curve
In [2]:
data_folder = '../../../autodl-tmp/data/'
results_dir = data_folder
if not os.path.exists(results_dir):
os.makedirs(results_dir)
Load data¶
In [3]:
num_train = 3500
num_test = 1000
num_inp_fn_points = 2601 # number of grid points for the input function
num_out_fn_points = 2601 # number of evaluations points for the output function
num_Y_components = 1 # scalar field
num_tr_outputs = 100 # number of outputs from the trunk network before they are multiplied
num_br_outputs = 100 # number of outputs from the branch and trunk networks before they are multiplied
out_coordinate_dimension = 2 # domain for output function is 2D
# training hyperparameters
batch_size = 20
epochs = 1000
lr = 1.0e-3
act_fn = torch.relu
data_prefix = 'Poisson'
data = DataProcessor(data_folder + data_prefix + '_samples_no-ood.npz', num_train, num_test, num_inp_fn_points, num_out_fn_points, num_Y_components)
train_data = {'X_train': data.X_train, 'X_trunk': data.X_trunk, 'Y_train': data.Y_train}
test_data = {'X_train': data.X_test, 'X_trunk': data.X_trunk, 'Y_train': data.Y_test}
print('X_train:',data.X_train.shape)
print('Y_train:',data.Y_train.shape)
print('X_test:',data.X_test.shape)
print('Y_test:',data.Y_test.shape)
print('X_trunk:',data.X_trunk.shape)
X_train: (3500, 2601) Y_train: (3500, 2601) X_test: (1000, 2601) Y_test: (1000, 2601) X_trunk: (2601, 2)
Create model and train the network¶
In [4]:
num_layers = 3
num_neurons = 64
model_save_path = results_dir + 'DeepONet/'
model_save_file = model_save_path + 'model.pkl'
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
model = DeepONet(num_layers, num_neurons, act_fn, num_br_outputs, \
num_tr_outputs, num_inp_fn_points, \
out_coordinate_dimension, num_Y_components,\
save_file = model_save_file)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of trainable parameters: {}'.format(trainable_params))
Using device: cuda Number of trainable parameters: 188041
In [5]:
# save the data and info
data_to_save = data.get_data_to_save()
model_metadata = { 'data': data_to_save, \
'num_train': num_train, \
'num_test': num_test, \
'num_inp_fn_points': num_inp_fn_points, \
'num_out_fn_points': num_out_fn_points, \
'num_Y_components': num_Y_components, \
'num_tr_outputs': num_tr_outputs, \
'num_br_outputs': num_br_outputs, \
'out_coordinate_dimension': out_coordinate_dimension, \
'num_layers': num_layers, \
'num_neurons': num_neurons, \
'epochs': epochs, \
'batch_size': batch_size, \
'lr': lr}
# attach it to the model
model.metadata = model_metadata
In [6]:
# Train
model.train(train_data, test_data, batch_size=batch_size, \
epochs = epochs, lr = lr, \
save_model = True, save_epoch = 100)
-------------------------------------------------- Starting training with 188041 trainable parameters... -------------------------------------------------- -------------------------------------------------- Epoch: 1, Train Loss (l2 squared): 4.326e-01, Test Loss (l2 squared): 2.641e-01, Time (sec): 0.475 -------------------------------------------------- -------------------------------------------------- Epoch: 100, Train Loss (l2 squared): 2.265e-02, Test Loss (l2 squared): 2.927e-02, Time (sec): 0.268 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 100 -------------------------------------------------- -------------------------------------------------- Epoch: 200, Train Loss (l2 squared): 1.672e-02, Test Loss (l2 squared): 1.715e-02, Time (sec): 0.410 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 200 -------------------------------------------------- -------------------------------------------------- Epoch: 300, Train Loss (l2 squared): 1.346e-02, Test Loss (l2 squared): 1.580e-02, Time (sec): 0.390 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 300 -------------------------------------------------- -------------------------------------------------- Epoch: 400, Train Loss (l2 squared): 1.162e-02, Test Loss (l2 squared): 1.408e-02, Time (sec): 0.350 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 400 -------------------------------------------------- -------------------------------------------------- Epoch: 500, Train Loss (l2 squared): 1.111e-02, Test Loss (l2 squared): 1.359e-02, Time (sec): 0.275 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 500 -------------------------------------------------- -------------------------------------------------- Epoch: 600, Train Loss (l2 squared): 1.086e-02, Test Loss (l2 squared): 1.343e-02, Time (sec): 0.302 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 600 -------------------------------------------------- -------------------------------------------------- Epoch: 700, Train Loss (l2 squared): 1.069e-02, Test Loss (l2 squared): 1.338e-02, Time (sec): 0.403 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 700 -------------------------------------------------- -------------------------------------------------- Epoch: 800, Train Loss (l2 squared): 1.060e-02, Test Loss (l2 squared): 1.337e-02, Time (sec): 0.365 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 800 -------------------------------------------------- -------------------------------------------------- Epoch: 900, Train Loss (l2 squared): 1.056e-02, Test Loss (l2 squared): 1.331e-02, Time (sec): 0.433 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 900 -------------------------------------------------- -------------------------------------------------- Epoch: 1000, Train Loss (l2 squared): 1.054e-02, Test Loss (l2 squared): 1.331e-02, Time (sec): 0.495 -------------------------------------------------- -------------------------------------------------- Model parameters saved at epoch 1000 -------------------------------------------------- -------------------------------------------------- Train time: 424.790, Epochs: 1000, Batch Size: 20, Final Train Loss (l2 squared): 1.054e-02, Final Test Loss (l2 squared): 1.331e-02 --------------------------------------------------
In [7]:
plot_loss( model.train_loss_log[:, 0], \
model.test_loss_log[:, 0], \
fs = 14, lw = 2, \
savefile = results_dir+'loss_his.png', \
figsize = [6,6])
Test and plot the output of network¶
In [8]:
# load the model
model = torch.load(model_save_file, weights_only=False)
In [9]:
Y_test = test_data['Y_train']
Y_test_pred = model.predict(test_data['X_train'], test_data['X_trunk']).detach().cpu().numpy()
print('test_out shape: {}, test_pred shape: {}'.format(Y_test.shape, Y_test_pred.shape))
error = np.linalg.norm(Y_test - Y_test_pred, axis = 1)/np.linalg.norm(Y_test, axis = 1)
print('Num tests: {:5d}, Mean Loss (rel l2): {:.3e}, Std Loss (rel l2): {:.3e}'.format(num_test, np.mean(error), np.std(error)))
test_out shape: (1000, 2601), test_pred shape: (1000, 2601) Num tests: 1000, Mean Loss (rel l2): 1.420e-01, Std Loss (rel l2): 5.056e-02
In [10]:
def apply_dirichlet_bc(u, bc_value, bc_node_ids):
u[bc_node_ids] = bc_value
return u
In [11]:
rows, cols = 4, 4
fs = 20
fig, axs = plt.subplots(rows, cols, figsize=(16, 13))
decode = True
apply_dirichlet_bc_flag = True
# row: m, u_true, u_pred, u_diff
u_tags = [r'$m$', r'$u_{true}$', r'$u_{pred}$', r'$u_{true} - u_{pred}$']
cmaps = ['jet', 'viridis', 'viridis', 'hot']
nodes = data.X_trunk
# randomly choose rows number of samples
i_choices = np.random.choice(num_test, rows, replace=False)
for i in range(rows):
i_plot = i_choices[i]
i_pred = Y_test_pred[i_plot]
i_truth = Y_test[i_plot]
i_m_test = data.X_test[i_plot]
if decode:
i_pred = data.decoder_Y(i_pred)
i_truth = data.decoder_Y(i_truth)
i_m_test = data.decoder_X(i_m_test)
if apply_dirichlet_bc_flag:
i_pred = apply_dirichlet_bc(i_pred, 0.0, data.u_mesh_dirichlet_boundary_nodes)
# verify for i_truth
if np.abs(i_truth[data.u_mesh_dirichlet_boundary_nodes]).max() > 1.0e-9:
print('Warning: Dirichlet BC not applied to i_truth. Err : {}'.format(np.abs(i_truth[data.u_mesh_dirichlet_boundary_nodes]).max()))
i_diff = i_pred - i_truth
i_diff_norm = np.linalg.norm(i_diff) / np.linalg.norm(i_truth)
print('i_plot = {:5d}, error (rel l2): {:.3e}'.format(i_plot, i_diff_norm))
uvec = [i_m_test, i_truth, i_pred, i_diff]
for j in range(cols):
cbar = field_plot(axs[i,j], uvec[j], nodes, cmap = cmaps[j])
divider = make_axes_locatable(axs[i,j])
cax = divider.append_axes('right', size='8%', pad=0.03)
cax.tick_params(labelsize=fs)
if j == 0 or j == cols - 1:
# format cbar ticks
kfmt = lambda x, pos: "{:g}".format(x)
cbar = fig.colorbar(cbar, cax=cax, orientation='vertical', format = kfmt)
else:
cbar = fig.colorbar(cbar, cax=cax, orientation='vertical')
if i == 0 and j < cols - 1:
axs[i,j].set_title(u_tags[j], fontsize=fs)
if j == cols - 1:
err_str = 'err (rel l2): {:.3f}%'.format(i_diff_norm*100)
if i == 0:
err_str = u_tags[j] + '\n' + err_str
axs[i,j].set_title(err_str, fontsize=fs)
axs[i,j].axis('off')
fig.tight_layout()
fig.suptitle('Poisson problem: Compare neural operator predictions ({})'.format(model.name), fontsize=1.25*fs, y=1.025)
fig.savefig(results_dir+'neural_operator_prediction_comparison.png', bbox_inches='tight')
plt.show()
i_plot = 993, error (rel l2): 1.045e-02 i_plot = 859, error (rel l2): 1.557e-02 i_plot = 298, error (rel l2): 9.743e-03 i_plot = 553, error (rel l2): 1.436e-02
HMC Uncertainty Quantification¶
In [12]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
# Use a subset of test data for HMC (full dataet may be too slow)
num_hmc_samples_data = 1000 # number of data points to use
hmc_indices = np.random.choice(num_test, num_hmc_samples_data, replace=False)
x_branch_hmc = test_data['X_train'][hmc_indices]
x_trunk_hmc = test_data['X_trunk']
y_hmc = test_data['Y_train'][hmc_indices]
print(f"Using {num_hmc_samples_data} test samples for HMC")
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
# Initialize from current model parameters
flat0 = uq.pack_params(model).to(device)
print(f"Initial parameter vector shape: {flat0.shape}")
# Create log probability function
log_prob = uq.make_log_prob_fn(model, x_branch_hmc, x_trunk_hmc, y_hmc,
noise_std=0.05, prior_std=1.0)
# Adaptive HMC settings
hmc_num_samples = 2000
hmc_burn_in = 200 # Increased burn-in for adaptation
hmc_adapt_steps = 150 # Steps to adapt step size
hmc_initial_step_size = 1e-7
hmc_leapfrog_steps = 20
hmc_target_accept = 0.75 # Target acceptance rate (65-80% is optimal)
print(f"\nAdaptive HMC Settings:")
print(f" num_samples: {hmc_num_samples}")
print(f" burn_in: {hmc_burn_in}")
print(f" adapt_steps: {hmc_adapt_steps}")
print(f" initial_step_size: {hmc_initial_step_size}")
print(f" leapfrog_steps: {hmc_leapfrog_steps}")
print(f" target_accept: {hmc_target_accept}")
print()
hmcsamples, acc_rate, final_step_size, step_size_history = uq.hmc_adaptive(
log_prob,
flat0.requires_grad_(True),
target_accept=hmc_target_accept,
initial_step_size=hmc_initial_step_size,
leapfrog_steps=hmc_leapfrog_steps,
num_samples=hmc_num_samples,
burn_in=hmc_burn_in,
adapt_steps=hmc_adapt_steps
)
print(f"\n{'='*60}")
print(f"Final Results:")
print(f" Acceptance rate: {acc_rate:.3f} ({acc_rate*100:.1f}%)")
print(f" Final step size: {final_step_size:.2e}")
print(f" Collected {len(hmcsamples)} samples")
print(f"{'='*60}")
Using 1000 test samples for HMC Model has 188041 parameters Initial parameter vector shape: torch.Size([188041]) Adaptive HMC Settings: num_samples: 2000 burn_in: 200 adapt_steps: 150 initial_step_size: 1e-07 leapfrog_steps: 20 target_accept: 0.75 Starting adaptive HMC with target acceptance rate: 75.00% Adaptation will run for 150 iterations Iter 50/2200: accept rate = 0.700, step_size = 1.78e-09, phase = adapting Iter 100/2200: accept rate = 0.730, step_size = 5.41e-09, phase = adapting Iter 150/2200: accept rate = 0.733, step_size = 9.77e-10, phase = adapting >>> Adaptation complete! Final step size: 1.64e-09 >>> Acceptance rate during adaptation: 0.728 Iter 200/2200: accept rate = 0.705, step_size = 1.64e-09, phase = burn-in Iter 250/2200: accept rate = 0.712, step_size = 1.64e-09, phase = sampling Iter 300/2200: accept rate = 0.727, step_size = 1.64e-09, phase = sampling Iter 350/2200: accept rate = 0.726, step_size = 1.64e-09, phase = sampling Iter 400/2200: accept rate = 0.735, step_size = 1.64e-09, phase = sampling Iter 450/2200: accept rate = 0.729, step_size = 1.64e-09, phase = sampling Iter 500/2200: accept rate = 0.728, step_size = 1.64e-09, phase = sampling Iter 550/2200: accept rate = 0.731, step_size = 1.64e-09, phase = sampling Iter 600/2200: accept rate = 0.742, step_size = 1.64e-09, phase = sampling Iter 650/2200: accept rate = 0.740, step_size = 1.64e-09, phase = sampling Iter 700/2200: accept rate = 0.747, step_size = 1.64e-09, phase = sampling Iter 750/2200: accept rate = 0.744, step_size = 1.64e-09, phase = sampling Iter 800/2200: accept rate = 0.748, step_size = 1.64e-09, phase = sampling Iter 850/2200: accept rate = 0.749, step_size = 1.64e-09, phase = sampling Iter 900/2200: accept rate = 0.753, step_size = 1.64e-09, phase = sampling Iter 950/2200: accept rate = 0.746, step_size = 1.64e-09, phase = sampling Iter 1000/2200: accept rate = 0.747, step_size = 1.64e-09, phase = sampling Iter 1050/2200: accept rate = 0.747, step_size = 1.64e-09, phase = sampling Iter 1100/2200: accept rate = 0.747, step_size = 1.64e-09, phase = sampling Iter 1150/2200: accept rate = 0.752, step_size = 1.64e-09, phase = sampling Iter 1200/2200: accept rate = 0.753, step_size = 1.64e-09, phase = sampling Iter 1250/2200: accept rate = 0.756, step_size = 1.64e-09, phase = sampling Iter 1300/2200: accept rate = 0.758, step_size = 1.64e-09, phase = sampling Iter 1350/2200: accept rate = 0.752, step_size = 1.64e-09, phase = sampling Iter 1400/2200: accept rate = 0.750, step_size = 1.64e-09, phase = sampling Iter 1450/2200: accept rate = 0.752, step_size = 1.64e-09, phase = sampling Iter 1500/2200: accept rate = 0.751, step_size = 1.64e-09, phase = sampling Iter 1550/2200: accept rate = 0.751, step_size = 1.64e-09, phase = sampling Iter 1600/2200: accept rate = 0.754, step_size = 1.64e-09, phase = sampling Iter 1650/2200: accept rate = 0.755, step_size = 1.64e-09, phase = sampling Iter 1700/2200: accept rate = 0.754, step_size = 1.64e-09, phase = sampling Iter 1750/2200: accept rate = 0.755, step_size = 1.64e-09, phase = sampling Iter 1800/2200: accept rate = 0.752, step_size = 1.64e-09, phase = sampling Iter 1850/2200: accept rate = 0.753, step_size = 1.64e-09, phase = sampling Iter 1900/2200: accept rate = 0.758, step_size = 1.64e-09, phase = sampling Iter 1950/2200: accept rate = 0.758, step_size = 1.64e-09, phase = sampling Iter 2000/2200: accept rate = 0.757, step_size = 1.64e-09, phase = sampling Iter 2050/2200: accept rate = 0.758, step_size = 1.64e-09, phase = sampling Iter 2100/2200: accept rate = 0.758, step_size = 1.64e-09, phase = sampling Iter 2150/2200: accept rate = 0.760, step_size = 1.64e-09, phase = sampling Iter 2200/2200: accept rate = 0.760, step_size = 1.64e-09, phase = sampling ============================================================ Final Results: Acceptance rate: 0.760 (76.0%) Final step size: 1.64e-09 Collected 2000 samples ============================================================
In [13]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
std1, result1 = uq.uqevaluation(num_test, test_data, model, 'hmc', hmcsamples=hmcsamples.clone())
model = torch.load(model_save_file, weights_only=False)
model.to(device)
uq.plot_uq(num_test, test_data, model, data, 'hmc', hmcsamples=hmcsamples.clone())
============================================================ Comprehensive Uncertainty Evaluation ============================================================ Evaluating uncertainty on 200 test samples... Computing posterior predictions... ============================================================ Uncertainty Quality Metrics ============================================================ 1. PREDICTION ACCURACY: RMSE: 0.115185 MAE: 0.082426 Mean Relative L2 Error: 13.67% Std Relative L2 Error: 5.00% 2. CALIBRATION (Coverage Analysis): Coverage within 1σ: 42.4% (ideal: 68.3%) Coverage within 2σ: 70.8% (ideal: 95.4%) Coverage within 3σ: 85.5% (ideal: 99.7%) Status: UNDER-CONFIDENT (uncertainties too large) 3. SHARPNESS (Uncertainty Magnitude): Mean Epistemic σ: 0.000019 Mean Total σ: 0.050000 Mean Aleatoric σ: 0.050000 (fixed) 4. UNCERTAINTY-ERROR CORRELATION: Pearson correlation: 0.520 → Good! High uncertainty correlates with high error 5. UNCERTAINTY DECOMPOSITION: Epistemic fraction: 0.0% Aleatoric fraction: 100.0% → Data noise dominates (model is confident) 6. PROPER SCORING RULES: Negative Log-Likelihood: 0.5767 (Lower is better) ============================================================ Predictions shape: (31, 5, 2601)
Monte-Carlo Dropout¶
In [15]:
import importlib
importlib.reload(uq)
model = torch.load(model_save_file, weights_only=False)
model.to(device)
# Ensure dropout is enabled
uq.inject_dropout(model)
torch.nn.Module.train(model)
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
torch.nn.Module.train(module)
In [16]:
std2, result2 = uq.uqevaluation(num_test, test_data, model, 'mcd')
uq.plot_uq(num_test, test_data, model, data, 'mcd')
Evaluating uncertainty on 200 test samples... Computing MC Dropout predictions... ============================================================ Uncertainty Quality Metrics ============================================================ 1. PREDICTION ACCURACY: RMSE: 0.309600 MAE: 0.181104 Mean Relative L2 Error: 32.22% Std Relative L2 Error: 6.74% 2. CALIBRATION (Coverage Analysis): Coverage within 1σ: 99.6% (ideal: 68.3%) Coverage within 2σ: 100.0% (ideal: 95.4%) Coverage within 3σ: 100.0% (ideal: 99.7%) Status: OVER-CONFIDENT (uncertainties too small) 3. SHARPNESS (Uncertainty Magnitude): Mean Epistemic σ: 0.748571 Mean Total σ: 0.750548 Mean Aleatoric σ: 0.050000 (fixed) 4. UNCERTAINTY-ERROR CORRELATION: Pearson correlation: 0.914 → Good! High uncertainty correlates with high error 5. UNCERTAINTY DECOMPOSITION: Epistemic fraction: 99.3% Aleatoric fraction: 0.7% → Model uncertainty dominates (more data may help) 6. PROPER SCORING RULES: Negative Log-Likelihood: 0.5770 (Lower is better) ============================================================ Predictions shape: (5, 5, 2601)
Laplacian Approximation¶
In [16]:
# Reload the original model (without dropout) for Laplace approximation
model = torch.load(model_save_file, weights_only=False)
model.to(device)
# Laplace approximation settings
noise_std_laplace = 0.05 # Assumed observation noise (same as HMC/MC Dropout for fair comparison)
prior_std_laplace = 1.0 # Prior standard deviation for weights
print(f"\nLaplace Approximation Settings:")
print(f" Assumed noise std: {noise_std_laplace}")
print(f" Prior std: {prior_std_laplace}")
# Use a subset of training data for computing the Hessian (full dataet may be too expensive)
num_laplace_data = min(500, num_train)
laplace_indices = np.random.choice(num_train, num_laplace_data, replace=False)
x_branch_laplace = train_data['X_train'][laplace_indices]
x_trunk_laplace = train_data['X_trunk']
y_laplace = train_data['Y_train'][laplace_indices]
print(f"Using {num_laplace_data} training samples for Hessian computation")
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
print("\nComputing diagonal Hessian approximation...")
H_diag = uq.compute_diagonal_hessian(
model, x_branch_laplace, x_trunk_laplace, y_laplace,
noise_std_laplace, prior_std_laplace, device, batch_size=10, sample_points_per_batch=30
)
# Posterior variance (diagonal approximation)
# σ²_posterior = 1 / H_diag
# IMPORTANT: Clip the posterior variance to prevent excessively large perturbations
# Parameters with very low Hessian (low curvature) would otherwise have huge variance,
# causing the model to produce wildly different outputs.
# We use a maximum posterior std of 0.01 (relative to typical weight magnitudes)
max_posterior_var = 0.01 ** 2 # Maximum variance corresponds to std of 0.01
posterior_var = 1.0 / (H_diag + 1e-8) # Raw posterior variance
posterior_var = torch.clamp(posterior_var, max=max_posterior_var) # Clip to prevent huge perturbations
print(f"\nHessian diagonal statistics:")
print(f" Mean H_diag: {H_diag.mean().item():.4e}")
print(f" Max H_diag: {H_diag.max().item():.4e}")
print(f" Min H_diag: {H_diag.min().item():.4e}")
print(f" Mean posterior σ (clipped): {torch.sqrt(posterior_var).mean().item():.4e}")
print(f" Max posterior σ (clipped): {torch.sqrt(posterior_var).max().item():.4e}")
# Clean up Hessian computation variables
del x_branch_laplace, x_trunk_laplace, y_laplace
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ============================================================
# Sample from the Laplace posterior
# θ ~ N(θ_MAP, diag(1/H_diag))
# ============================================================
print("\nSampling from Laplace posterior...")
num_laplace_samples = 2000 # Number of posterior samples
theta_map = uq.pack_params(model).to(device)
# Compute posterior std once
posterior_std_vec = torch.sqrt(posterior_var)
# Sample from the posterior and store on CPU to save GPU memory
laplace_samples = []
for i in range(num_laplace_samples):
# Sample: θ = θ_MAP + σ_posterior * ε, where ε ~ N(0, I)
epsilon = torch.randn_like(theta_map)
theta_sample = theta_map + posterior_std_vec * epsilon
laplace_samples.append(theta_sample.cpu())
del epsilon, theta_sample
laplace_samples = torch.stack(laplace_samples)
print(f"Generated {num_laplace_samples} posterior samples")
# Free memory from Hessian computation
del H_diag, posterior_var, posterior_std_vec
torch.cuda.empty_cache() if torch.cuda.is_available() else None
Laplace Approximation Settings: Assumed noise std: 0.05 Prior std: 1.0 Using 500 training samples for Hessian computation Model has 188041 parameters Computing diagonal Hessian approximation... Processed 100/500 samples Processed 200/500 samples Processed 300/500 samples Processed 400/500 samples Processed 500/500 samples Hessian diagonal statistics: Mean H_diag: 1.5710e+06 Max H_diag: 4.7072e+09 Min H_diag: 1.0000e+00 Mean posterior σ (clipped): 3.6809e-03 Max posterior σ (clipped): 1.0000e-02 Sampling from Laplace posterior... Generated 2000 posterior samples
In [17]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
std3, result3 = uq.uqevaluation(num_test, test_data, model, 'la', lasamples=laplace_samples.clone())
model = torch.load(model_save_file, weights_only=False)
model.to(device)
uq.plot_uq(num_test, test_data, model, data, 'la', lasamples=laplace_samples.clone())
============================================================ Comprehensive Uncertainty Evaluation ============================================================ Evaluating uncertainty on 200 test samples... Computing Laplace posterior predictions... ============================================================ Uncertainty Quality Metrics ============================================================ 1. PREDICTION ACCURACY: RMSE: 0.114734 MAE: 0.081775 Mean Relative L2 Error: 13.99% Std Relative L2 Error: 5.24% 2. CALIBRATION (Coverage Analysis): Coverage within 1σ: 47.4% (ideal: 68.3%) Coverage within 2σ: 76.7% (ideal: 95.4%) Coverage within 3σ: 90.0% (ideal: 99.7%) Status: UNDER-CONFIDENT (uncertainties too large) 3. SHARPNESS (Uncertainty Magnitude): Mean Epistemic σ: 0.026829 Mean Total σ: 0.057163 Mean Aleatoric σ: 0.050000 (fixed) 4. UNCERTAINTY-ERROR CORRELATION: Pearson correlation: 0.623 → Good! High uncertainty correlates with high error 5. UNCERTAINTY DECOMPOSITION: Epistemic fraction: 22.4% Aleatoric fraction: 77.6% → Balanced uncertainty sources 6. PROPER SCORING RULES: Negative Log-Likelihood: -0.1035 (Lower is better) ============================================================ Predictions shape: (2000, 5, 2601)
Comparison¶
In [34]:
uq.comparison_uq(result1, result2, result3)
====================================================================== COMPARISON: HMC vs MC Dropout vs Laplace Approximation ====================================================================== Metric HMC MC Dropout Laplace Ideal ------------------------------------------------------------------------------------- RMSE 0.115185 0.287664 0.114734 Lower MAE 0.082426 0.174901 0.081775 Lower Mean Rel. L2 Error (%) 13.67 32.70 13.99 Lower Coverage 1σ (%) 42.4 99.6 47.4 68.3 Coverage 2σ (%) 70.8 100.0 76.7 95.4 Coverage 3σ (%) 85.5 100.0 90.0 99.7 Mean Epistemic σ 0.000019 0.707368 0.026829 - Mean Total σ 0.050000 0.709417 0.057163 - Epistemic Fraction (%) 0.0 99.3 22.4 - Uncertainty-Error Corr. 0.520 0.890 0.623 Higher NLL 0.5767 0.5347 -0.1035 Lower
Behaviour on OOD data¶
In [8]:
datas = DataProcessor(data_folder + data_prefix + '_samples_ood.npz', 3, 600, num_inp_fn_points, num_out_fn_points, num_Y_components)
data_ood = {'X_train': datas.X_test, 'X_trunk': datas.X_trunk, 'Y_train': datas.Y_test}
In [35]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
std4, result4 = uq.uqevaluation(600, data_ood, model, 'hmc', hmcsamples=hmcsamples.clone())
============================================================ Comprehensive Uncertainty Evaluation ============================================================ Evaluating uncertainty on 200 test samples... Computing posterior predictions... ============================================================ Uncertainty Quality Metrics ============================================================ 1. PREDICTION ACCURACY: RMSE: 29.997958 MAE: 17.350393 Mean Relative L2 Error: 173.69% Std Relative L2 Error: 110.90% 2. CALIBRATION (Coverage Analysis): Coverage within 1σ: 2.0% (ideal: 68.3%) Coverage within 2σ: 3.7% (ideal: 95.4%) Coverage within 3σ: 5.3% (ideal: 99.7%) Status: UNDER-CONFIDENT (uncertainties too large) 3. SHARPNESS (Uncertainty Magnitude): Mean Epistemic σ: 0.000541 Mean Total σ: 0.050008 Mean Aleatoric σ: 0.050000 (fixed) 4. UNCERTAINTY-ERROR CORRELATION: Pearson correlation: 0.999 → Good! High uncertainty correlates with high error 5. UNCERTAINTY DECOMPOSITION: Epistemic fraction: 0.0% Aleatoric fraction: 100.0% → Data noise dominates (model is confident) 6. PROPER SCORING RULES: Negative Log-Likelihood: 179783.0826 (Lower is better) ============================================================
In [36]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
uq.inject_dropout(model)
torch.nn.Module.train(model)
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
torch.nn.Module.train(module)
std5, result5 = uq.uqevaluation(600, data_ood, model, 'mcd')
============================================================ Comprehensive Uncertainty Evaluation ============================================================ Evaluating uncertainty on 200 test samples... Computing MC Dropout predictions... ============================================================ Uncertainty Quality Metrics ============================================================ 1. PREDICTION ACCURACY: RMSE: 30.600975 MAE: 17.816100 Mean Relative L2 Error: 176.09% Std Relative L2 Error: 109.56% 2. CALIBRATION (Coverage Analysis): Coverage within 1σ: 62.0% (ideal: 68.3%) Coverage within 2σ: 93.3% (ideal: 95.4%) Coverage within 3σ: 99.9% (ideal: 99.7%) Status: WELL-CALIBRATED 3. SHARPNESS (Uncertainty Magnitude): Mean Epistemic σ: 19.316290 Mean Total σ: 19.316795 Mean Aleatoric σ: 0.050000 (fixed) 4. UNCERTAINTY-ERROR CORRELATION: Pearson correlation: 0.997 → Good! High uncertainty correlates with high error 5. UNCERTAINTY DECOMPOSITION: Epistemic fraction: 99.9% Aleatoric fraction: 0.1% → Model uncertainty dominates (more data may help) 6. PROPER SCORING RULES: Negative Log-Likelihood: 3.4872 (Lower is better) ============================================================
In [22]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
std6, result6 = uq.uqevaluation(600, data_ood, model, 'la', lasamples=laplace_samples.clone())
============================================================ Comprehensive Uncertainty Evaluation ============================================================ Evaluating uncertainty on 200 test samples... Computing Laplace posterior predictions... ============================================================ Uncertainty Quality Metrics ============================================================ 1. PREDICTION ACCURACY: RMSE: 31.981734 MAE: 19.142937 Mean Relative L2 Error: 178.93% Std Relative L2 Error: 117.28% 2. CALIBRATION (Coverage Analysis): Coverage within 1σ: 3.4% (ideal: 68.3%) Coverage within 2σ: 6.7% (ideal: 95.4%) Coverage within 3σ: 10.2% (ideal: 99.7%) Status: UNDER-CONFIDENT (uncertainties too large) 3. SHARPNESS (Uncertainty Magnitude): Mean Epistemic σ: 0.502801 Mean Total σ: 0.512050 Mean Aleatoric σ: 0.050000 (fixed) 4. UNCERTAINTY-ERROR CORRELATION: Pearson correlation: 0.996 → Good! High uncertainty correlates with high error 5. UNCERTAINTY DECOMPOSITION: Epistemic fraction: 81.8% Aleatoric fraction: 18.2% → Model uncertainty dominates (more data may help) 6. PROPER SCORING RULES: Negative Log-Likelihood: 571.6778 (Lower is better) ============================================================
In [37]:
uq.comparison_uq(result4, result5, result6)
====================================================================== COMPARISON: HMC vs MC Dropout vs Laplace Approximation ====================================================================== Metric HMC MC Dropout Laplace Ideal ------------------------------------------------------------------------------------- RMSE 29.997958 30.600975 31.981734 Lower MAE 17.350393 17.816100 19.142937 Lower Mean Rel. L2 Error (%) 173.69 176.09 178.93 Lower Coverage 1σ (%) 2.0 62.0 3.4 68.3 Coverage 2σ (%) 3.7 93.3 6.7 95.4 Coverage 3σ (%) 5.3 99.9 10.2 99.7 Mean Epistemic σ 0.000541 19.316290 0.502801 - Mean Total σ 0.050008 19.316795 0.512050 - Epistemic Fraction (%) 0.0 99.9 81.8 - Uncertainty-Error Corr. 0.999 0.997 0.996 Higher NLL 179783.0826 3.4872 571.6778 Lower
ood data detection¶
In [38]:
hmc_ood_eval = np.concatenate((std1.mean(axis=1),std4.mean(axis=1)), axis=0) # epistemic + aleatoric
hmc_ood_eval = hmc_ood_eval/np.max(hmc_ood_eval)
mcd_ood_eval = np.concatenate((std2.mean(axis=1),std5.mean(axis=1)), axis=0)
mcd_ood_eval = mcd_ood_eval/np.max(mcd_ood_eval)
la_ood_eval = np.concatenate((std3.mean(axis=1),std6.mean(axis=1)), axis=0)
la_ood_eval = la_ood_eval/np.max(la_ood_eval)
oods = np.concatenate((np.zeros(std1.shape[0]), np.ones(std4.shape[0])), axis=0) # 0 for ID, 1 for OOD
#Examine OOD data:
# Step 1: Generate uncertainty scores: ood_eval
# Step 2: Create true labels for AUROC (1 for OOD, 0 for ID): oods
# Step 3: Calculate AUROC
for ood, uqmethod in zip([hmc_ood_eval, mcd_ood_eval, la_ood_eval], ['HMC', 'MC Dropout', 'Laplace Approximation']):
auroc = roc_auc_score(oods, ood)
fpr, tpr, _ = roc_curve(oods, ood)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUROC = {auroc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f'Receiver Operating Characteristic (ROC) using {uqmethod}')
plt.legend(loc='lower right')
plt.show()
print(f"AUROC for OOD detection: {auroc}")
AUROC for OOD detection: 0.9461124999999999
AUROC for OOD detection: 0.94215
AUROC for OOD detection: 0.9379750000000001
Conformal Prediction¶
In [9]:
model = torch.load(model_save_file, weights_only=False)
model.to(device)
alpha = 0.1 # Target error rate (90% coverage)
n_cal = 500 # Number of samples for calibration
# Split remaining test data into calibration and final test
indices = np.arange(num_test)
np.random.shuffle(indices)
cal_idx = indices[:n_cal]
test_idx = indices[n_cal:]
# Get predictions for calibration set
x_branch_cal = test_data['X_train'][cal_idx]
x_trunk_cal = test_data['X_trunk']
y_cal = test_data['Y_train'][cal_idx]
with torch.no_grad():
x_b = torch.from_numpy(x_branch_cal).float().to(device)
x_t = torch.from_numpy(x_trunk_cal).float().to(device)
y_cal_pred = model.predict(x_b, x_t).cpu().numpy()
# Calculate non-conformity scores (Absolute Residuals)
# For vector outputs, we use point-wise absolute error
scores = np.abs(y_cal - y_cal_pred)
# Compute the (1-alpha) quantile of the scores
q_level = np.ceil((n_cal + 1) * (1 - alpha)) / n_cal
qhat = np.quantile(scores, q_level, axis=0) # Point-wise quantile
print(f"Calibration complete using {n_cal} samples.")
print(f"Quantile (qhat) mean value: {qhat.mean():.4f}")
Calibration complete using 500 samples. Quantile (qhat) mean value: 0.1694
In [10]:
# Evaluate on the remaining test set
x_branch_final = np.concatenate((test_data['X_train'][test_idx], data_ood['X_train']), axis=0)
y_test_final = np.concatenate((test_data['Y_train'][test_idx], data_ood['Y_train']), axis=0)
with torch.no_grad():
x_b_final = torch.from_numpy(x_branch_final).float().to(device)
y_test_pred = model.predict(x_b_final, x_t).cpu().numpy()
# Prediction Intervals: [pred - qhat, pred + qhat]
lower_bound = y_test_pred - qhat
upper_bound = y_test_pred + qhat
# Calculate empirical coverage
cover = ((y_test_final >= lower_bound) & (y_test_final <= upper_bound)).mean(axis=1)
coverage = cover.mean()
cover_id = cover[:len(test_idx)]
cover_ood = cover[len(test_idx):]
cover_id_mean = cover_id.mean()
cover_ood_mean = cover_ood.mean()
print(f"Results for alpha = {alpha} (Target Coverage: {1-alpha:.1%})")
print(f"Empirical Test Coverage: {coverage:.1%}")
print(f"ID Coverage: {cover_id_mean:.1%}")
print(f"OOD Coverage: {cover_ood_mean:.1%}")
print(f"Average Prediction Interval Width: {(upper_bound - lower_bound).mean():.4f}")
Results for alpha = 0.1 (Target Coverage: 90.0%) Empirical Test Coverage: 44.3% ID Coverage: 89.8% OOD Coverage: 6.4% Average Prediction Interval Width: 0.3388
In [11]:
oods = np.concatenate((np.ones(len(test_idx)), np.zeros(len(data_ood['X_train']))), axis=0)
#Examine OOD data:
# Step 1: Generate uncertainty scores: ood_eval
# Step 2: Create true labels for AUROC (1 for OOD, 0 for ID): oods
# Step 3: Calculate AUROC
auroc = roc_auc_score(oods, cover)
fpr, tpr, _ = roc_curve(oods, cover)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUROC = {auroc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
plt.show()
print(f"AUROC for OOD detection: {auroc}")
AUROC for OOD detection: 0.9999266666666667
In [ ]: